{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "949f360e",
   "metadata": {},
   "source": [
    "## Protein Folding with ESMFold and 🤗`transformers`"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ab50e270",
   "metadata": {},
   "source": [
    "ESMFold ([paper link](https://www.biorxiv.org/content/10.1101/2022.07.20.500902v2)) is a recently released protein folding model from FAIR. Unlike other protein folding models, it does not require external databases or search tools to predict structures, and is up to 60X faster as a result.\n",
    "\n",
    "The port to the HuggingFace `transformers` library is even easier to use, as we've removed the dependency on tools like `openfold` - once you `pip install transformers`, you're ready to use this model! \n",
    "\n",
    "Note that all the code that follows will be running the model **locally**, rather than calling an external API. This means that no rate limiting applies here - you can predict as many structures as your computer can handle. \n",
    "\n",
    "In testing, we found that ESMFold needs about 16-24GB of GPU memory to run well, depending on protein length. This may be too much for the smaller free GPUs on Colab."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a2f53405",
   "metadata": {},
   "source": [
    "First step, make sure you're up to date - you'll need the most recent release of `transformers` and `accelerate`! If you want to visualize your predicted protein structure in the notebook, you should also install py3Dmol. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb29483f",
   "metadata": {},
   "outputs": [],
   "source": [
    "! pip install --upgrade transformers py3Dmol accelerate"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1dca9819",
   "metadata": {},
   "source": [
    "## Preparing your model and tokenizer"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c418e286",
   "metadata": {},
   "source": [
    "Now we load our model and tokenizer. We use the `get_backend` API from accelerate to automatically detect the underlying hardware and move the model to the device."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c200c170",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer, EsmForProteinFolding\n",
    "from accelerate.test_utils.testing import get_backend\n",
    "\n",
    "DEVICE, _, _ = get_backend()\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"facebook/esmfold_v1\")\n",
    "model = EsmForProteinFolding.from_pretrained(\"facebook/esmfold_v1\", low_cpu_mem_usage=True)\n",
    "\n",
    "model = model.to(DEVICE)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a6f43d78",
   "metadata": {},
   "source": [
    "## Performance optimizations"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cc0f0186",
   "metadata": {},
   "source": [
    "Since ESMFold is quite a large model, there are some considerations regarding memory usage and performance.\n",
    "\n",
    "Firstly, we can optionally convert the language model stem to float16 to improve performance and memory usage when running on a modern GPU. This was used during model training, and so should not make the outputs from the rest of the model invalid."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "90ee986d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Uncomment to switch the stem to float16\n",
    "model.esm = model.esm.half()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7c647c77",
   "metadata": {},
   "source": [
    "Secondly, you can enable TensorFloat32 computation for a general speedup if your hardware supports it. This line has no effect if your hardware doesn't support it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2bd9e11",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "torch.backends.cuda.matmul.allow_tf32 = True"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a8eefe1b",
   "metadata": {},
   "source": [
    "Finally, we can reduce the 'chunk_size' used in the folding trunk. Smaller chunk sizes use less memory, but have slightly worse performance."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f8ba985",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Uncomment this line if your GPU memory is 16GB or less, or if you're folding longer (over 600 or so) sequences\n",
    "model.trunk.set_chunk_size(64)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c9a26e91",
   "metadata": {},
   "source": [
    "## Folding a single chain"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8752706a",
   "metadata": {},
   "source": [
    "First, we tokenize our input. If you've used `transformers` before, proteins are processed like any other input string. Make sure **not** to add special tokens - ESM was trained with them, but ESMFold was trained without them. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dde34627",
   "metadata": {},
   "outputs": [],
   "source": [
    "# This is the sequence for human GNAT1, because I worked on it when\n",
    "# I was a postdoc and so everyone else has to learn to appreciate it too.\n",
    "# Feel free to substitute your own peptides of interest\n",
    "# Depending on memory constraints you may wish to use shorter sequences.\n",
    "test_protein = \"MGAGASAEEKHSRELEKKLKEDAEKDARTVKLLLLGAGESGKSTIVKQMKIIHQDGYSLEECLEFIAIIYGNTLQSILAIVRAMTTLNIQYGDSARQDDARKLMHMADTIEEGTMPKEMSDIIQRLWKDSGIQACFERASEYQLNDSAGYYLSDLERLVTPGYVPTEQDVLRSRVKTTGIIETQFSFKDLNFRMFDVGGQRSERKKWIHCFEGVTCIIFIAALSAYDMVLVEDDEVNRMHESLHLFNSICNHRYFATTSIVLFLNKKDVFFEKIKKAHLSICFPDYDGPNTYEDAGNYIKVQFLELNMRRDVKEIYSHMTCATDTQNVKFVFDAVTDIIIKENLKDCGLF\"\n",
    "\n",
    "tokenized_input = tokenizer([test_protein], return_tensors=\"pt\", add_special_tokens=False)['input_ids']\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "18c00d09",
   "metadata": {},
   "source": [
    "If you're using a GPU, you'll need to move the tokenized data to the GPU now."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0520279",
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenized_input = tokenized_input.to(DEVICE)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "89029c8b",
   "metadata": {},
   "source": [
    "With our preparations out of the way, getting your model outputs is as simple as..."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "707fad0d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "with torch.no_grad():\n",
    "    output = model(tokenized_input)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b34de2f6",
   "metadata": {},
   "source": [
    "Now here's the tricky bit - we convert the model outputs to a PDB file. This will likely be moved to a function in `transformers` in the future, but everything's still quite new, so it lives here for now! This code comes from the original ESMFold repo, and uses some functions from `openfold` that have been ported to `transformers`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1c9de19e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers.models.esm.openfold_utils.protein import to_pdb, Protein as OFProtein\n",
    "from transformers.models.esm.openfold_utils.feats import atom14_to_atom37\n",
    "\n",
    "def convert_outputs_to_pdb(outputs):\n",
    "    final_atom_positions = atom14_to_atom37(outputs[\"positions\"][-1], outputs)\n",
    "    outputs = {k: v.to(\"cpu\").numpy() for k, v in outputs.items()}\n",
    "    final_atom_positions = final_atom_positions.cpu().numpy()\n",
    "    final_atom_mask = outputs[\"atom37_atom_exists\"]\n",
    "    pdbs = []\n",
    "    for i in range(outputs[\"aatype\"].shape[0]):\n",
    "        aa = outputs[\"aatype\"][i]\n",
    "        pred_pos = final_atom_positions[i]\n",
    "        mask = final_atom_mask[i]\n",
    "        resid = outputs[\"residue_index\"][i] + 1\n",
    "        pred = OFProtein(\n",
    "            aatype=aa,\n",
    "            atom_positions=pred_pos,\n",
    "            atom_mask=mask,\n",
    "            residue_index=resid,\n",
    "            b_factors=outputs[\"plddt\"][i],\n",
    "            chain_index=outputs[\"chain_index\"][i] if \"chain_index\" in outputs else None,\n",
    "        )\n",
    "        pdbs.append(to_pdb(pred))\n",
    "    return pdbs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24613dbb",
   "metadata": {},
   "outputs": [],
   "source": [
    "pdb = convert_outputs_to_pdb(output)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f23adc4e",
   "metadata": {},
   "source": [
    "Now we have our pdb - can we visualize it?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e094b965",
   "metadata": {},
   "outputs": [],
   "source": [
    "import py3Dmol\n",
    "\n",
    "view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js', width=800, height=400)\n",
    "view.addModel(\"\".join(pdb), 'pdb')\n",
    "view.setStyle({'model': -1}, {\"cartoon\": {'color': 'spectrum'}})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "04b1f814",
   "metadata": {},
   "source": [
    "Looks good! We can colour it differently, though - our model outputs a `plddt` field containing probabilities for each atom, indicating how confident it is in that part of the structure. In the conversion function above we added the `plddt` field in the `b_factors` argument, so it was included in our `pdb` string. Let's use it so that we can see high- and low-confidence areas of the structure visually!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7add289c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The plddt field is scaled from 0-1 on earlier versions of ESMFold but will be updated\n",
    "# to match AlphaFold's scale of 0-100 in future versions.\n",
    "# We check here so that this code will work on either:\n",
    "\n",
    "if torch.max(output['plddt']) <= 1.0:\n",
    "    vmin = 0.5\n",
    "    vmax = 0.95\n",
    "else:\n",
    "    vmin = 50\n",
    "    vmax = 95\n",
    "\n",
    "view.setStyle({'cartoon': {'colorscheme': {'prop':'b','gradient': 'roygb','min': vmin,'max': vmax}}})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b5546395",
   "metadata": {},
   "source": [
    "Blue indicates high confidence, so that's a pretty high-quality prediction! Not too surprising considering GNAT1 was almost certainly in the training data, but nevertheless good to see. Finally, we can write our PDB string out to a file, which you can download and use in other tools."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f7867e6c",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(\"output_structure.pdb\", \"w\") as f:\n",
    "    f.write(\"\".join(pdb))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cd9cdbee",
   "metadata": {},
   "source": [
    "If you're running this in Colab (and haven't run out of memory by now!) then you can download the file we just created using the file browser interface at the left - the button looks like a little folder icon."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ad5c537a",
   "metadata": {},
   "source": [
    "## Folding multiple chains"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1cad2c0a",
   "metadata": {},
   "source": [
    "Many proteins exist as complexes, either as multiple copies of the same peptide (a homopolymer), or a complex of different ones (a heteropolymer). To generate folds for such structures in ESMFold, we use a trick from the paper - we insert a \"linker\" of flexible glycine residues between each chain we want to fold simultaneously, and then we offset the position IDs for each chain from each other, so that the model treats them as being very distant portions of the same long chain. This works quite well, so let's see it in action! We'll use Glucosamine-6-phosphate deaminase (Uniprot: Q9CMF4) from the paper as an example."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "74f335dc",
   "metadata": {},
   "source": [
    "First, we define the sequence of the monomer, and the poly-G linker we want to use. Then we stick two copies of the monomer together with the linker in between."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab826ff8",
   "metadata": {},
   "outputs": [],
   "source": [
    "sequence = \"MRLIPLHNVDQVAKWSARYIVDRINQFQPTEARPFVLGLPTGGTPLKTYEALIELYKAGEVSFKHVVTFNMDEYVGLPKEHPESYHSFMYKNFFDHVDIQEKNINILNGNTEDHDAECQRYEEKIKSYGKIHLFMGGVGVDGHIAFNEPASSLSSRTRIKTLTEDTLIANSRFFDNDVNKVPKYALTIGVGTLLDAEEVMILVTGYNKAQALQAAVEGSINHLWTVTALQMHRRAIIVCDEPATQELKVKTVKYFTELEASAIRSVK\"\n",
    "\n",
    "linker = 'G' * 25\n",
    "\n",
    "homodimer_sequence = sequence + linker + sequence"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b7abd3f5",
   "metadata": {},
   "source": [
    "Now we tokenize the full homodimer sequence just like we did with the monomer sequence above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1dd0534c",
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenized_homodimer = tokenizer([homodimer_sequence], return_tensors=\"pt\", add_special_tokens=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a9b66b67",
   "metadata": {},
   "source": [
    "Now here's the tricky bit - we need to tweak the inputs a bit so the model doesn't think this is just a single peptide. The way we do that is by using the `position_ids` input to the model. The `position_ids` input tells the model the position of each amino acid in the input chain. By default, the model assumes that you've passed it one linear, contiguous chain - in other words, if you give it a peptide with 100 amino acids, it will assume the `position_ids` are just `[0, 1, ..., 98, 99]` unless you tell it otherwise.\n",
    "\n",
    "We want to make very clear that the two subunits aren't connected, though, so let's add a large offset to the position IDs of the second chain. The original repo uses 512, so let's stick with that."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3851d39",
   "metadata": {},
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "    position_ids = torch.arange(len(homodimer_sequence), dtype=torch.long)\n",
    "    position_ids[len(sequence) + len(linker):] += 512\n",
    "print(position_ids)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9baf48ce",
   "metadata": {},
   "source": [
    "Now we're ready to predict! Let's add our `position_ids` to the tokenized inputs, but make sure to add a singleton batch dimension first to match the other arrays in there! Once that's done we can transfer that dict to the GPU and we're ready to get our folded structure."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "159ee4b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenized_homodimer['position_ids'] = position_ids.unsqueeze(0)\n",
    "\n",
    "tokenized_homodimer = {key: tensor.to(DEVICE) for key, tensor in tokenized_homodimer.items()}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a1f14e1e",
   "metadata": {},
   "source": [
    "Now we compute predictions just like before."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c956feca",
   "metadata": {},
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "    output = model(**tokenized_homodimer)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3cb716c8",
   "metadata": {},
   "source": [
    "Next, we need to remove the poly-G linker from the output, so we can display the structure as fully independent chains. To do that, we'll alter the `atom37_atom_exists` field in the output. This field indicates, for display purposes, which atoms are present at each residue position. We will simply set all of the atoms for each of the linker residues to 0."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0fe7031b",
   "metadata": {},
   "outputs": [],
   "source": [
    "linker_mask = torch.tensor([1] * len(sequence) + [0] * len(linker) + [1] * len(sequence))[None, :, None]\n",
    "\n",
    "output['atom37_atom_exists'] = output['atom37_atom_exists'] * linker_mask.to(output['atom37_atom_exists'].device)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0162645f",
   "metadata": {},
   "source": [
    "With those output tweaks done, now we can convert the output to PDB and view it as before."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93ac8f58",
   "metadata": {},
   "outputs": [],
   "source": [
    "pdb = convert_outputs_to_pdb(output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9beab944",
   "metadata": {},
   "outputs": [],
   "source": [
    "view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js', width=800, height=400)\n",
    "view.addModel(\"\".join(pdb), 'pdb')\n",
    "\n",
    "# The plddt field is scaled from 0-1 on earlier versions of ESMFold but will be updated\n",
    "# to match AlphaFold's scale of 0-100 in future versions.\n",
    "# We check here so that this code will work on either:\n",
    "\n",
    "if torch.max(output['plddt']) <= 1.0:\n",
    "    vmin = 0.5\n",
    "    vmax = 0.95\n",
    "else:\n",
    "    vmin = 50\n",
    "    vmax = 95\n",
    "\n",
    "view.setStyle({'cartoon': {'colorscheme': {'prop':'b','gradient': 'roygb','min': vmin,'max': vmax}}})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "800e8322",
   "metadata": {},
   "source": [
    "And there's our dimer structure! As in the first example, we can now write this structure out to a PDB file and use it in downstream tools:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b3ef7389",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(\"output_structure.pdb\", \"w\") as f:\n",
    "    f.write(\"\".join(pdb))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4245b6d9",
   "metadata": {},
   "source": [
    "**Tip**: If you're trying to predict a multimeric structure and you're getting low-quality outputs, try varying the order of the chains (if it's a heteropolymer) or the length of the linker."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "375ce618",
   "metadata": {},
   "source": [
    "## Bulk predictions "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "94795531",
   "metadata": {},
   "source": [
    "Predicting single structures is nice, but the great advantage of running ESMFold locally is the fact that it's extremely fast while still producing highly accurate predictions. This makes it very suitable for proteomics work. Let's see that in action here - we're going to grab a set of monomeric proteins in E. Coli from Uniprot and fold all of them with high accuracy on a single GPU in a couple of minutes (depending on your GPU!)\n",
    "\n",
    "We do this because we can, and to upset any crystallographer friends we may have. First, you may need to install `requests`, `tqdm` and `pandas` if you don't have them already, to handle the data we grab from Uniprot."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1805096",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Uncomment and run this cell to install\n",
    "#! pip install requests pandas tqdm"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8bd0d5c3",
   "metadata": {},
   "source": [
    "Next, let's prepare the URL for our Uniprot query."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f19dabb7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import requests\n",
    "\n",
    "uniprot_url = \"https://rest.uniprot.org/uniprotkb/stream?compressed=true&fields=accession%2Csequence&format=tsv&query=%28%28taxonomy_id%3A83333%29%20AND%20%28reviewed%3Atrue%29%20AND%20%28length%3A%5B128%20TO%20512%5D%29%20AND%20%28cc_subunit%3Amonomer%29%29\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "59d0f3f9",
   "metadata": {},
   "source": [
    "This uniprot URL might seem mysterious, but it isn't! To get it, we searched for `(taxonomy_id:83333) AND (reviewed:true) AND (length:[128 TO 512]) AND (cc_subunit:monomer)` on UniProt to get all monomeric E.coli proteins of reasonable length, then selected 'Download', and set the format to TSV and the columns to `Sequence`.\n",
    "\n",
    "Once that's done, selecting `Generate URL for API` gives you a URL you can pass to Requests. Alternatively, if you're not on Colab you can just download the data through the web interface and open the file locally."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b8fbedb4",
   "metadata": {},
   "outputs": [],
   "source": [
    "uniprot_request = requests.get(uniprot_url)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e2cbdce6",
   "metadata": {},
   "source": [
    "To get this data into Pandas, we use a `BytesIO` object, which Pandas will treat like a file. If you downloaded the data as a file you can skip this bit and just pass the filepath directly to `read_csv`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7da0934",
   "metadata": {},
   "outputs": [],
   "source": [
    "from io import BytesIO\n",
    "import pandas\n",
    "\n",
    "bio = BytesIO(uniprot_request.content)\n",
    "\n",
    "df = pandas.read_csv(bio, compression='gzip', sep='\\t')\n",
    "df = df.dropna()  # Remove empty columns, just in case\n",
    "df"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c9eec628",
   "metadata": {},
   "source": [
    "If you have time, you could process this entire list, giving you folded structures for the entire monomeric proteome of E. Coli. For the sake of this demo, though, let's limit ourselves to 10:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "334e1706",
   "metadata": {},
   "outputs": [],
   "source": [
    "df = df.iloc[:10]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d50a31ab",
   "metadata": {},
   "source": [
    "Now let's pull out the sequences and batch-tokenize all of them."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c4f3c91",
   "metadata": {},
   "outputs": [],
   "source": [
    "ecoli_tokenized = tokenizer(df.Sequence.tolist(), padding=False, add_special_tokens=False)['input_ids']"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bef2d74f",
   "metadata": {},
   "source": [
    "Now we loop over our tokenized data, passing each sequence into our model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e2eb7e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "outputs = []\n",
    "\n",
    "with torch.no_grad():\n",
    "    for input_ids in tqdm(ecoli_tokenized):\n",
    "        input_ids = torch.tensor(input_ids, device=DEVICE).unsqueeze(0)\n",
    "        output = model(input_ids)\n",
    "        outputs.append({key: val.cpu() for key, val in output.items()})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9bda2e50",
   "metadata": {},
   "source": [
    "Now we have 10 model outputs, which we can convert to PDB in bulk. If you get an error here, make sure you've run the cell above that defines the convert_outputs_to_pdb function!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "005233a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "pdb_list = [convert_outputs_to_pdb(output) for output in outputs]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d2da5185",
   "metadata": {},
   "source": [
    "Let's inspect one of them to see what we got."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1683a505",
   "metadata": {},
   "outputs": [],
   "source": [
    "import py3Dmol\n",
    "\n",
    "view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js', width=800, height=400)\n",
    "view.addModel(\"\".join(pdb_list[0]), 'pdb')\n",
    "\n",
    "# The plddt field is scaled from 0-1 on earlier versions of ESMFold but will be updated\n",
    "# to match AlphaFold's scale of 0-100 in future versions.\n",
    "# We check here so that this code will work on either:\n",
    "\n",
    "if torch.max(output['plddt']) <= 1.0:\n",
    "    vmin = 0.5\n",
    "    vmax = 0.95\n",
    "else:\n",
    "    vmin = 50\n",
    "    vmax = 95\n",
    "\n",
    "view.setStyle({'cartoon': {'colorscheme': {'prop':'b','gradient': 'roygb','min': vmin,'max': vmax}}})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "411cf51a",
   "metadata": {},
   "source": [
    "Looks good to me! Now we can save all of these to disc together."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "07076a8c",
   "metadata": {},
   "outputs": [],
   "source": [
    "protein_identifiers = df.Entry.tolist()\n",
    "for identifier, pdb in zip(protein_identifiers, pdb_list):\n",
    "    with open(f\"{identifier}.pdb\", \"w\") as f:\n",
    "        f.write(\"\".join(pdb))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e657713f",
   "metadata": {},
   "source": [
    "If you're on Colab, you can download all of these to your local machine using the file interface on the left."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
